[MLLIB] SPARK-1547: Add Gradient Boosting to MLlib #2607
Conversation
Test FAILed.
@manishamde Thanks for the WIP PR! About classification, what points need to be discussed? Why is it more difficult to figure out than regression? (Also, I personally am not a big fan of the name "deviance" even though it is used in sklearn and in Friedman's paper. I prefer more descriptive names like LogLoss.) Will this also be generalized to support weighted weak hypotheses, which are common in most boosting algorithms? For the final model produced, should we use the same class for both random forests and gradient boosting? It could be a TreeEnsemble model (to be generalized later to a WeightedEnsemble model).
@epahomov If you or your student are able to take a look at this, I'm sure @manishamde would appreciate it. This PR will hopefully be generalized to include classification. It's nice in that it has infrastructure for multiple losses. Thank you!
I meant multi-class classification. As you pointed out, binary classification should be similar to the regression case, but I am not sure one can handle multi-class classification with one tree. We might have to resort to a one-vs-all strategy. I also agree with you on the naming convention -- log loss or negative binomial log likelihood are better names. Yes, I plan to handle weighted weak hypotheses. In fact, I needed them for something like AdaBoost and had to remove them before submitting this PR. Do you think it makes sense to do it along with this PR or in the subsequent AdaBoost PR? I agree about the WeightedEnsemble model. Let me add it to the TODO list.
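As a concrete reading of the naming discussion above (LogLoss versus "deviance"), here is a minimal sketch of a pluggable loss with a log-loss implementation; the trait, the method names, and the {-1, +1} label convention are assumptions for illustration, not this PR's actual code.

```scala
// Minimal sketch of a pluggable loss; trait and method names are illustrative.
trait Loss extends Serializable {
  /** Gradient of the loss with respect to the current prediction (margin). */
  def gradient(prediction: Double, label: Double): Double

  /** Loss value for a single (prediction, label) pair. */
  def computeError(prediction: Double, label: Double): Double
}

/** Log loss (negative binomial log-likelihood) for labels in {-1, +1}. */
object LogLoss extends Loss {
  override def gradient(prediction: Double, label: Double): Double =
    -4.0 * label / (1.0 + math.exp(2.0 * label * prediction))

  override def computeError(prediction: Double, label: Double): Double =
    2.0 * math.log1p(math.exp(-2.0 * label * prediction))
}
```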
@manishamde Weighted weak hypotheses: I am OK if this initial PR does not include weights, but then weights should be prioritized for the next update. For the WeightedEnsemble, that generalization could be part of this PR or a follow-up. Once this is ready, I'll be happy to help with testing (e.g., to tune checkpointing intervals and measure general performance). Thanks!
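To make the TreeEnsemble/WeightedEnsemble idea above concrete, here is a minimal sketch of a model that combines weighted weak hypotheses; the class name, constructor, and fields are illustrative assumptions, not the API introduced by this PR.

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Rough sketch of a weighted ensemble of trees; names are illustrative only.
class WeightedTreeEnsemble(
    val trees: Array[DecisionTreeModel],
    val weights: Array[Double]) extends Serializable {

  require(trees.length == weights.length,
    "Each tree must have a corresponding weight.")

  /** Weighted sum of the individual tree predictions. */
  def predict(features: Vector): Double =
    trees.zip(weights).map { case (tree, w) => w * tree.predict(features) }.sum
}
```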
@jkbradley Error-correcting codes will be a good option to support, though we should also have a generic one-vs-all classifier. Yes, weight support will definitely be a part of the AdaBoost PR. Let's discuss the WeightedEnsemble as we get close to completing the PR. Thanks for helping with the testing. I am currently implementing the TreeRdd caching and subsampling without replacement. After that, we can start testing in parallel along with further code development.
Sounds good!
QA tests have started for PR 2607 at commit
QA tests have finished for PR 2607 at commit
Test FAILed.
QA tests have started for PR 2607 at commit
QA tests have finished for PR 2607 at commit
Test FAILed.
QA tests have started for PR 2607 at commit
I have added stochastic gradient boosting by adding code for subsampling without replacement.
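For reference, per-iteration subsampling without replacement can be expressed with RDD.sample; the subsamplingRate parameter and helper below are illustrative, not this PR's actual code.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

// Illustrative sketch: draw a fresh subsample (without replacement) for
// each boosting iteration, as in stochastic gradient boosting.
def subsampleForIteration(
    input: RDD[LabeledPoint],
    subsamplingRate: Double,
    iteration: Int): RDD[LabeledPoint] = {
  // withReplacement = false gives sampling without replacement; varying the
  // seed per iteration yields a different subsample each time.
  input.sample(withReplacement = false, fraction = subsamplingRate, seed = iteration)
}
```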
Here is an interesting design discussion: for trees and RFs, we convert the input to an internal format and cache it before training. Here are a few approaches we can take for boosting: I have implemented (1), but I think (2) will be worthwhile to try. Any suggestions?
QA tests have finished for PR 2607 at commit
Test FAILed.
@manishamde About the 2 caching options, I agree with your decision to do (1) first. It would be nice to try (2) later on (another PR?), but I don't think it is too high-priority. Perhaps we can eventually have learning algs provide convertDatasetToInternalFormat() and predictUsingInternalFormat() methods (with less verbose names), once the standard API is in place.
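A rough sketch of the kind of interface described above; the two method names are borrowed from the comment, while the trait, type parameter, and signatures are hypothetical.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

// Hypothetical sketch of option (2): let a learner expose its internal
// representation so boosting can convert the dataset once and reuse it.
trait InternalFormatLearner[InternalPoint] {
  /** Convert the raw dataset into the learner's internal representation once. */
  def convertDatasetToInternalFormat(input: RDD[LabeledPoint]): RDD[InternalPoint]

  /** Predict directly on the cached internal representation. */
  def predictUsingInternalFormat(data: RDD[InternalPoint]): RDD[Double]
}
```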
@jkbradley Cool. I am sure we will see a definite performance gain once we implement support for (2) after the standard API is in place.
Test PASSed.
Test build #22536 has finished for PR 2607 at commit
Test PASSed.
Thanks for the updates; I'll take a look. I think that it will be very important to include checkpointing, but I am OK with adding it later on. (Since boosting is sequential, I could imagine it running for much longer than bagging/forest algorithms, so protecting against driver failure will be important.)
@jkbradley I agree with protection against driver failure for long sequential operations. However, in this case we will just be checkpointing partial models rather than the intermediate datasets, similar to other iterative algorithms such as LR. Looking forward to your feedback on the new logic.
True, perhaps we'll need to checkpoint not just the labels but also the data itself for Spark to know how to resume training. Postponing checkpointing seems like a good idea for now. |
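For context on the checkpointing discussion, here is a minimal sketch of how periodic RDD checkpointing typically looks inside an iterative Spark loop; the checkpoint directory, the checkpointInterval parameter, and the placeholder update step are assumptions for illustration, not this PR's design (which, as noted above, may checkpoint partial models instead).

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

// Minimal sketch of periodic RDD checkpointing in an iterative loop.
// Checkpointing truncates the lineage, so a long-running sequential job
// (like boosting) does not have to replay every prior iteration on failure.
def runWithCheckpointing(
    sc: SparkContext,
    data: RDD[LabeledPoint],
    numIterations: Int,
    checkpointInterval: Int): Unit = {
  sc.setCheckpointDir("/tmp/boosting-checkpoints") // hypothetical directory
  var working: RDD[LabeledPoint] = data
  for (m <- 1 to numIterations) {
    // Placeholder update step: real boosting would recompute pseudo-residuals
    // against the current partial model here.
    working = working.map(lp => LabeledPoint(lp.label, lp.features))
    if (m % checkpointInterval == 0) {
      working.checkpoint() // write to reliable storage
      working.count()      // force materialization of the checkpoint
    }
  }
}
```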
I don't think toString() should print the full model. toString should be concise, and toDebugString should print the full model.
Will do.
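A small sketch of the convention being requested here; the class name and fields are illustrative, not the PR's actual model class.

```scala
// Illustrative convention: toString stays concise, toDebugString prints detail.
class EnsembleModelSketch(val numTrees: Int, val treeDescriptions: Seq[String]) {

  /** Concise one-line summary, safe to log. */
  override def toString: String = s"EnsembleModelSketch with $numTrees trees"

  /** Full model dump, intended only for debugging. */
  def toDebugString: String = toString + "\n" + treeDescriptions.mkString("\n")
}
```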
@manishamde The logic looks better (especially since you caught the learningRate bug!). After the API update (train*, BoostingStrategy, and making AbsoluteError and LogLoss private), I think this will be ready.
@jkbradley Thanks for the confirmation! I will now proceed to finish the rest of the tasks -- should be straightforward.
This should be learningRate too, right?
I think the learning rate is applied after the first model.
In the Friedman paper, the first "model" is just the average label (for squared error). I think it's fine to keep it as is; that way, running for just 1 iteration will behave reasonably.
Yup.
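To make the exchange above concrete: for squared error, the initial "model" F_0 is the mean training label, and the learning rate shrinks each subsequent tree but not F_0. A minimal sketch of that prediction rule, with illustrative names:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Illustrative prediction rule for squared-error gradient boosting:
// F_0(x) is the mean training label; every later tree (fit to the
// pseudo-residuals) is shrunk by the learning rate.
def boostedPrediction(
    meanLabel: Double,               // F_0: average label of the training set
    trees: Seq[DecisionTreeModel],   // trees fit at iterations 1..M
    learningRate: Double,
    features: Vector): Double = {
  meanLabel + trees.map(tree => learningRate * tree.predict(features)).sum
}
```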
Test build #22582 has started for PR 2607 at commit
Test build #22582 has finished for PR 2607 at commit
Test PASSed.
Test build #22596 has started for PR 2607 at commit
@jkbradley I cleaned up the public API based on our discussion. Going with a nested structure where we have to specify the weak learner parameters separately is cleaner, but it puts the onus on us to write very good documentation. I am tempted to keep AbsoluteError and LogLoss as is, with the appropriate caveats in the documentation. A regression tree with mean prediction at the terminal nodes is not the best approximation (as pointed out by the TreeBoost paper), but it's not a bad one either. After all, we are just making approximations of the gradient at each step. Moreover, other weak learning algorithms (for example LR) will be hard to tailor towards each specific loss function. Thoughts?
True, it's a good point about LR. OK, let's keep them with caveats, but hopefully run some tests to make sure they seem to be working. I'll make a pass tomorrow morning; thanks for the updates!
Test build #22596 has finished for PR 2607 at commit
Test PASSed.
@manishamde LGTM! Thanks for updating the Strategy. I think this is ready to be merged, though I still plan to update the train* methods to eliminate the ones taking lots of parameters. In particular, I plan to:
Thanks very much for contributing GBT! It's a big step forward for MLlib. CC: @mengxr |
Thanks. Sounds good to me. I tried to use the builder pattern to help with the Java use case, but I guess we can handle it separately.
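For reference, the builder-style configuration mentioned above could look roughly like this from Java-friendly Scala; all names, defaults, and fields here are illustrative assumptions, not the final API.

```scala
// Illustrative builder-style configuration object, intended to be easy to
// call from Java (chained setters instead of many positional parameters).
class BoostingConfig private (
    var numIterations: Int,
    var learningRate: Double,
    var lossName: String) {

  def setNumIterations(value: Int): this.type = { numIterations = value; this }
  def setLearningRate(value: Double): this.type = { learningRate = value; this }
  def setLoss(value: String): this.type = { lossName = value; this }
}

object BoostingConfig {
  /** Start from reasonable defaults and override via setters. */
  def default(): BoostingConfig = new BoostingConfig(100, 0.1, "SquaredError")
}

// Example usage:
// val config = BoostingConfig.default().setNumIterations(50).setLearningRate(0.05)
```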
@mengxr Could we get this merged? :-)
I've merged this into master. Thanks @manishamde for contributing and @codedeft and @jkbradley for review!
Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting, with AdaBoost to follow soon after this PR is accepted. This is based on work done along with @hirakendu that was pending due to decision tree optimizations and random forests work.
Ideally, boosting algorithms should work with any base learners. This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default.
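To illustrate the "pluggable" loss point, here is roughly what the two default regression losses could look like; the object names and method signatures are assumptions for this sketch, not the PR's actual classes, and in practice they would plug into a shared loss interface like the one sketched earlier in the thread.

```scala
// Standalone sketches of the two default regression losses.
object SquaredErrorSketch {
  /** Gradient of (label - prediction)^2 with respect to the prediction. */
  def gradient(prediction: Double, label: Double): Double =
    -2.0 * (label - prediction)

  def computeError(prediction: Double, label: Double): Double = {
    val err = label - prediction
    err * err
  }
}

object AbsoluteErrorSketch {
  /** Subgradient of |label - prediction| with respect to the prediction. */
  def gradient(prediction: Double, label: Double): Double =
    if (label - prediction < 0) 1.0 else -1.0

  def computeError(prediction: Double, label: Double): Double =
    math.abs(label - prediction)
}
```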
Here is the task list:
Future work:
cc: @jkbradley @hirakendu @mengxr @etrain @atalwalkar @chouqin